{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tool Calling and Structured Output\n",
    "\n",
    "In this notebook, we will try to verify the tool calling and structured output capabilities of [DeepSeek V3](https://cloud.siliconflow.cn/open/models?target=deepseek-ai/DeepSeek-V3) and [DeepSeek R1](https://cloud.siliconflow.cn/open/models?target=deepseek-ai/DeepSeek-R1) models from SiliconFlow.\n",
    "\n",
    "We will use both LangChain and LangGraph to test the tool calling and structured output capabilities of the two models.\n",
    "\n",
    "```bash\n",
    "# Get your API Key from: https://cloud.siliconflow.cn/account/ak\n",
    "export DEEPSEEK_API_KEY=<your-siliconcloud-api-key>\n",
    "export DEEPSEEK_API_BASE=https://api.siliconflow.cn/v1\n",
    "```\n",
    "\n",
    "## Defining Schemas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Union\n",
    "from pydantic import BaseModel, Field\n",
    "\n",
    "\n",
    "# Tool Schemas\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two integers.\n",
    "\n",
    "    Args:\n",
    "        a: First integer\n",
    "        b: Second integer\n",
    "    \"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two integers.\n",
    "\n",
    "    Args:\n",
    "        a: First integer\n",
    "        b: Second integer\n",
    "    \"\"\"\n",
    "    return a * b\n",
    "\n",
    "tools = [add, multiply]\n",
    "\n",
    "\n",
    "# Structured Output\n",
    "class Joke(BaseModel):\n",
    "    \"\"\"Joke to tell user.\"\"\"\n",
    "\n",
    "    setup: str = Field(description=\"The setup of the joke\")\n",
    "    punchline: str = Field(description=\"The punchline to the joke\")\n",
    "    rating: Optional[int] = Field(\n",
    "        default=None, description=\"How funny the joke is, from 1 to 10\"\n",
    "    )\n",
    "\n",
    "\n",
    "class ConversationalResponse(BaseModel):\n",
    "    \"\"\"Respond in a conversational manner. Be kind and helpful.\"\"\"\n",
    "\n",
    "    response: str = Field(description=\"A conversational response to the user's query\")\n",
    "\n",
    "\n",
    "class FinalResponse(BaseModel):\n",
    "    final_output: Union[Joke, ConversationalResponse]\n",
    "\n",
    "\n",
    "def print_stream(stream):\n",
    "    for s in stream:\n",
    "        message = s[\"messages\"][-1]\n",
    "        if isinstance(message, tuple):\n",
    "            print(message)\n",
    "        else:\n",
    "            message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test DeepSeek-V3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is 3 * 12?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import init_chat_model\n",
    "\n",
    "# pip install -qU langchain-deepseek\n",
    "llm = init_chat_model(\"deepseek-ai/DeepSeek-V3\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '0195a3675a60caaef4f57654a565fd71', 'function': {'arguments': '{\"a\":3,\"b\":12}', 'name': 'multiply'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 22, 'prompt_tokens': 165, 'total_tokens': 187, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'deepseek-ai/DeepSeek-V3', 'system_fingerprint': '', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-928d89c0-825a-48c2-a798-be5e5d5d71da-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': '0195a3675a60caaef4f57654a565fd71', 'type': 'tool_call'}], usage_metadata={'input_tokens': 165, 'output_tokens': 22, 'total_tokens': 187, 'input_token_details': {}, 'output_token_details': {}})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tool calling test\n",
    "llm_with_tools = llm.bind_tools(tools)\n",
    "llm_with_tools.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ LangSmith Trace: https://smith.langchain.com/public/bb3e22bb-e580-4726-ba3d-3c21796282a1/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FinalResponse(final_output=Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!', rating=None))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Structured output test\n",
    "structured_llm = llm.with_structured_output(FinalResponse)\n",
    "structured_llm.invoke(\"Tell me a joke about cats\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ LangSmith Trace: https://smith.langchain.com/public/b5a70e23-ab3f-4f04-9068-620b6efde5d0/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "what is 3 * 12?\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      "  multiply (0195a38e69ff4c957317510af7423d63)\n",
      " Call ID: 0195a38e69ff4c957317510af7423d63\n",
      "  Args:\n",
      "    a: 3\n",
      "    b: 12\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: multiply\n",
      "\n",
      "36\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      "  multiply (0195a38e94f9b9ab02fa093a145563b9)\n",
      " Call ID: 0195a38e94f9b9ab02fa093a145563b9\n",
      "  Args:\n",
      "    a: 3\n",
      "    b: 12\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: multiply\n",
      "\n",
      "36\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Hello! How can I assist you today?\n"
     ]
    }
   ],
   "source": [
    "from langgraph.prebuilt import create_react_agent\n",
    "\n",
    "# Tool calling test\n",
    "graph = create_react_agent(llm, tools=tools)\n",
    "\n",
    "inputs = {\"messages\": [(\"user\", \"what is 3 * 12?\")]}\n",
    "print_stream(graph.stream(inputs, stream_mode=\"values\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "❓ LangSmith Trace: https://smith.langchain.com/public/190768d1-ed6e-4713-8613-5c6a1c0ae4c7/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'messages': [HumanMessage(content='Tell me a joke about cats?', additional_kwargs={}, response_metadata={}, id='995fdda3-c9b7-4c00-b1fc-950a3f9966b5'),\n",
       "  AIMessage(content=\"Sure! Here's a joke about cats:\\n\\nWhy was the cat sitting on the computer?\\n\\nBecause it wanted to keep an eye on the mouse! 😹\\n\\nHope that gave you a chuckle! Let me know if you need anything else.\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 48, 'prompt_tokens': 176, 'total_tokens': 224, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'deepseek-ai/DeepSeek-V3', 'system_fingerprint': '', 'finish_reason': 'stop', 'logprobs': None}, id='run-52c025f8-87ac-4385-b728-9ad07ac0d838-0', usage_metadata={'input_tokens': 176, 'output_tokens': 48, 'total_tokens': 224, 'input_token_details': {}, 'output_token_details': {}})],\n",
       " 'structured_response': FinalResponse(final_output=Joke(setup='Why did the cat sit on the computer?', punchline='To keep an eye on the mouse!', rating=6))}"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph = create_react_agent(\n",
    "    llm,\n",
    "    tools=tools,\n",
    "    # specify the schema for the structured output using `response_format` parameter\n",
    "    response_format=FinalResponse,\n",
    ")\n",
    "\n",
    "inputs = {\"messages\": [(\"user\", \"Tell me a joke about cats?\")]}\n",
    "graph.invoke(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ LangSmith Trace: https://smith.langchain.com/public/29bbdccd-98e8-4aaf-a0db-4a092f361c9e/r"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test DeepSeek-R1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import init_chat_model\n",
    "\n",
    "# pip install -qU langchain-deepseek\n",
    "llm = init_chat_model(\"deepseek-ai/DeepSeek-R1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='\\n\\nLet me calculate that for you.\\n\\n**Step 1:** I need to multiply 3 by 12.\\n\\n**Step 2:** Using the multiply function: multiply(a=3, b=12).\\n\\n**Result:** The result of 3 * 12 is 36.\\n\\n`multiply(3, 12) = 36`', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 71, 'prompt_tokens': 168, 'total_tokens': 239, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'deepseek-ai/DeepSeek-R1', 'system_fingerprint': '', 'finish_reason': 'stop', 'logprobs': None}, id='run-9cf121ca-14ff-4060-b0b9-7ca8eb9fc017-0', usage_metadata={'input_tokens': 168, 'output_tokens': 71, 'total_tokens': 239, 'input_token_details': {}, 'output_token_details': {}})"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tool calling test\n",
    "llm_with_tools = llm.bind_tools(tools)\n",
    "llm_with_tools.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ LangSmith Trace: https://smith.langchain.com/public/0503ca60-1365-422e-9e78-5b85a92252a4/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FinalResponse(final_output=Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!', rating=8))"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Structured output test\n",
    "structured_llm = llm.with_structured_output(FinalResponse)\n",
    "structured_llm.invoke(\"Tell me a joke about cats\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ LangSmith Trace: https://smith.langchain.com/public/d849ba89-e77e-4c01-9bba-2ee874814121/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "what is 3 * 12?\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "\n",
      "\n",
      "Alright, let's solve 3 multiplied by 12. I need to calculate this using the `multiply` function. The parameters required are 'a' and 'b', both integers. Here, 3 and 12 are the integers to multiply.\n",
      "\n",
      "So, calling the multiply function with a=3 and b=12. Let me verify the calculation: 3 times 12 equals 36. Yes, that's correct. \n",
      "\n",
      "**Step-by-Step Explanation:**\n",
      "- **Function Used:** `multiply(a: 3, b: 12)`\n",
      "- **Calculation:** 3 * 12 = 36\n",
      "- **Result:** 36\n",
      "\n",
      "**Answer:** The result of 3 multiplied by 12 is \\boxed{36}.\n"
     ]
    }
   ],
   "source": [
    "from langgraph.prebuilt import create_react_agent\n",
    "\n",
    "# Tool calling test\n",
    "graph = create_react_agent(llm, tools=tools)\n",
    "\n",
    "inputs = {\"messages\": [(\"user\", \"what is 3 * 12?\")]}\n",
    "print_stream(graph.stream(inputs, stream_mode=\"values\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "❌ LangSmith Trace (Failed): https://smith.langchain.com/public/efa53517-d2be-434f-912b-eb708f47ffcc/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'messages': [HumanMessage(content='Tell me a joke about cats?', additional_kwargs={}, response_metadata={}, id='52bcebba-d74c-44f8-8272-0318258d2c6f'),\n",
       "  AIMessage(content=\"\\n\\nSure! Here's a light-hearted cat joke for you:\\n\\nWhy don’t cats play poker in the jungle?  \\n…Because there are too many *cheetahs*! 🐾😹\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 43, 'prompt_tokens': 179, 'total_tokens': 222, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'deepseek-ai/DeepSeek-R1', 'system_fingerprint': '', 'finish_reason': 'stop', 'logprobs': None}, id='run-2e983ba0-f7ff-48d4-acd0-858b1c923db6-0', usage_metadata={'input_tokens': 179, 'output_tokens': 43, 'total_tokens': 222, 'input_token_details': {}, 'output_token_details': {}})],\n",
       " 'structured_response': FinalResponse(final_output=Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!', rating=None))}"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph = create_react_agent(\n",
    "    llm,\n",
    "    tools=tools,\n",
    "    # specify the schema for the structured output using `response_format` parameter\n",
    "    response_format=FinalResponse,\n",
    ")\n",
    "\n",
    "inputs = {\"messages\": [(\"user\", \"Tell me a joke about cats?\")]}\n",
    "graph.invoke(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "✅ LangSmith Trace: https://smith.langchain.com/public/b13073d6-a96e-4be8-84fc-d361bd8089fd/r"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
